from datetime import datetime, timedelta
import re

import numpy as np
import pytest

import pandas as pd
from pandas import Index, Timedelta, TimedeltaIndex, notna, timedelta_range
import pandas._testing as tm


class TestGetItem:
    def test_ellipsis(self):
        # GH#21282
        idx = timedelta_range("1 day", "31 day", freq="D", name="idx")

        result = idx[...]
        assert result.equals(idx)
        assert result is not idx

    def test_getitem_slice_keeps_name(self):
        # GH#4226
        tdi = timedelta_range("1d", "5d", freq="H", name="timebucket")
        assert tdi[1:].name == tdi.name

    def test_getitem(self):
        idx1 = timedelta_range("1 day", "31 day", freq="D", name="idx")

        for idx in [idx1]:
            result = idx[0]
            assert result == Timedelta("1 day")

            result = idx[0:5]
            expected = timedelta_range("1 day", "5 day", freq="D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx[0:10:2]
            expected = timedelta_range("1 day", "9 day", freq="2D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx[-20:-5:3]
            expected = timedelta_range("12 day", "24 day", freq="3D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx[4::-1]
            expected = TimedeltaIndex(
                ["5 day", "4 day", "3 day", "2 day", "1 day"], freq="-1D", name="idx"
            )
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

    @pytest.mark.parametrize(
        "key",
        [
            pd.Timestamp("1970-01-01"),
            pd.Timestamp("1970-01-02"),
            datetime(1970, 1, 1),
            pd.Timestamp("1970-01-03").to_datetime64(),
            # non-matching NA values
            np.datetime64("NaT"),
        ],
    )
    def test_timestamp_invalid_key(self, key):
        # GH#20464
        tdi = timedelta_range(0, periods=10)
        with pytest.raises(KeyError, match=re.escape(repr(key))):
            tdi.get_loc(key)


class TestGetLoc:
    def test_get_loc(self):
        idx = pd.to_timedelta(["0 days", "1 days", "2 days"])

        for method in [None, "pad", "backfill", "nearest"]:
            assert idx.get_loc(idx[1], method) == 1
            assert idx.get_loc(idx[1].to_pytimedelta(), method) == 1
            assert idx.get_loc(str(idx[1]), method) == 1

        assert idx.get_loc(idx[1], "pad", tolerance=Timedelta(0)) == 1
        assert idx.get_loc(idx[1], "pad", tolerance=np.timedelta64(0, "s")) == 1
        assert idx.get_loc(idx[1], "pad", tolerance=timedelta(0)) == 1

        with pytest.raises(ValueError, match="unit abbreviation w/o a number"):
            idx.get_loc(idx[1], method="nearest", tolerance="foo")

        with pytest.raises(ValueError, match="tolerance size must match"):
            idx.get_loc(
                idx[1],
                method="nearest",
                tolerance=[
                    Timedelta(0).to_timedelta64(),
                    Timedelta(0).to_timedelta64(),
                ],
            )

        for method, loc in [("pad", 1), ("backfill", 2), ("nearest", 1)]:
            assert idx.get_loc("1 day 1 hour", method) == loc

        # GH 16909
        assert idx.get_loc(idx[1].to_timedelta64()) == 1

        # GH 16896
        assert idx.get_loc("0 days") == 0

    def test_get_loc_nat(self):
        tidx = TimedeltaIndex(["1 days 01:00:00", "NaT", "2 days 01:00:00"])

        assert tidx.get_loc(pd.NaT) == 1
        assert tidx.get_loc(None) == 1
        assert tidx.get_loc(float("nan")) == 1
        assert tidx.get_loc(np.nan) == 1


class TestGetIndexer:
    def test_get_indexer(self):
        idx = pd.to_timedelta(["0 days", "1 days", "2 days"])
        tm.assert_numpy_array_equal(
            idx.get_indexer(idx), np.array([0, 1, 2], dtype=np.intp)
        )

        target = pd.to_timedelta(["-1 hour", "12 hours", "1 day 1 hour"])
        tm.assert_numpy_array_equal(
            idx.get_indexer(target, "pad"), np.array([-1, 0, 1], dtype=np.intp)
        )
        tm.assert_numpy_array_equal(
            idx.get_indexer(target, "backfill"), np.array([0, 1, 2], dtype=np.intp)
        )
        tm.assert_numpy_array_equal(
            idx.get_indexer(target, "nearest"), np.array([0, 1, 1], dtype=np.intp)
        )

        res = idx.get_indexer(target, "nearest", tolerance=Timedelta("1 hour"))
        tm.assert_numpy_array_equal(res, np.array([0, -1, 1], dtype=np.intp))


class TestWhere:
    def test_where_doesnt_retain_freq(self):
        tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
        cond = [True, True, False]
        expected = TimedeltaIndex([tdi[0], tdi[1], tdi[0]], freq=None, name="idx")

        result = tdi.where(cond, tdi[::-1])
        tm.assert_index_equal(result, expected)

    def test_where_invalid_dtypes(self):
        tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")

        i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())

        msg = "value should be a 'Timedelta', 'NaT', or array of those"
        with pytest.raises(TypeError, match=msg):
            tdi.where(notna(i2), i2.asi8)

        with pytest.raises(TypeError, match=msg):
            tdi.where(notna(i2), i2 + pd.Timestamp.now())

        with pytest.raises(TypeError, match=msg):
            tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))

        with pytest.raises(TypeError, match=msg):
            # non-matching scalar
            tdi.where(notna(i2), pd.Timestamp.now())

    def test_where_mismatched_nat(self):
        tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
        cond = np.array([True, False, False])

        msg = "value should be a 'Timedelta', 'NaT', or array of those"
        with pytest.raises(TypeError, match=msg):
            # wrong-dtyped NaT
            tdi.where(cond, np.datetime64("NaT", "ns"))


class TestTake:
    def test_take(self):
        # GH 10295
        idx1 = timedelta_range("1 day", "31 day", freq="D", name="idx")

        for idx in [idx1]:
            result = idx.take([0])
            assert result == Timedelta("1 day")

            result = idx.take([-1])
            assert result == Timedelta("31 day")

            result = idx.take([0, 1, 2])
            expected = timedelta_range("1 day", "3 day", freq="D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx.take([0, 2, 4])
            expected = timedelta_range("1 day", "5 day", freq="2D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx.take([7, 4, 1])
            expected = timedelta_range("8 day", "2 day", freq="-3D", name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq == expected.freq

            result = idx.take([3, 2, 5])
            expected = TimedeltaIndex(["4 day", "3 day", "6 day"], name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq is None

            result = idx.take([-3, 2, 5])
            expected = TimedeltaIndex(["29 day", "3 day", "6 day"], name="idx")
            tm.assert_index_equal(result, expected)
            assert result.freq is None

    def test_take_invalid_kwargs(self):
        idx = timedelta_range("1 day", "31 day", freq="D", name="idx")
        indices = [1, 6, 5, 9, 10, 13, 15, 3]

        msg = r"take\(\) got an unexpected keyword argument 'foo'"
        with pytest.raises(TypeError, match=msg):
            idx.take(indices, foo=2)

        msg = "the 'out' parameter is not supported"
        with pytest.raises(ValueError, match=msg):
            idx.take(indices, out=indices)

        msg = "the 'mode' parameter is not supported"
        with pytest.raises(ValueError, match=msg):
            idx.take(indices, mode="clip")

    # TODO: This method came from test_timedelta; de-dup with version above
    def test_take2(self):
        tds = ["1day 02:00:00", "1 day 04:00:00", "1 day 10:00:00"]
        idx = timedelta_range(start="1d", end="2d", freq="H", name="idx")
        expected = TimedeltaIndex(tds, freq=None, name="idx")

        taken1 = idx.take([2, 4, 10])
        taken2 = idx[[2, 4, 10]]

        for taken in [taken1, taken2]:
            tm.assert_index_equal(taken, expected)
            assert isinstance(taken, TimedeltaIndex)
            assert taken.freq is None
            assert taken.name == expected.name

    def test_take_fill_value(self):
        # GH 12631
        idx = TimedeltaIndex(["1 days", "2 days", "3 days"], name="xxx")
        result = idx.take(np.array([1, 0, -1]))
        expected = TimedeltaIndex(["2 days", "1 days", "3 days"], name="xxx")
        tm.assert_index_equal(result, expected)

        # fill_value
        result = idx.take(np.array([1, 0, -1]), fill_value=True)
        expected = TimedeltaIndex(["2 days", "1 days", "NaT"], name="xxx")
        tm.assert_index_equal(result, expected)

        # allow_fill=False
        result = idx.take(np.array([1, 0, -1]), allow_fill=False, fill_value=True)
        expected = TimedeltaIndex(["2 days", "1 days", "3 days"], name="xxx")
        tm.assert_index_equal(result, expected)

        msg = (
            "When allow_fill=True and fill_value is not None, "
            "all indices must be >= -1"
        )
        with pytest.raises(ValueError, match=msg):
            idx.take(np.array([1, 0, -2]), fill_value=True)
        with pytest.raises(ValueError, match=msg):
            idx.take(np.array([1, 0, -5]), fill_value=True)

        msg = "index -5 is out of bounds for (axis 0 with )?size 3"
        with pytest.raises(IndexError, match=msg):
            idx.take(np.array([1, -5]))
